import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import lpips

from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim


def preprocess_image(img):
    """Turn a single image into a normalized, batched tensor.

    The image is resized to 256x256, converted to a tensor with
    ``transforms.ToTensor`` and then shifted/scaled channel-wise with
    mean 0.5 and std 0.5 (the original author's note: values end up in
    [-1, 1], the range used for LPIPS input).

    Args:
        img: Image accepted by the torchvision transforms below
            (presumably a PIL image -- confirm against callers).

    Returns:
        The transformed image with an extra leading batch dimension.
    """
    steps = [
        # Fixed spatial size (noted by the original author as the LPIPS input size).
        transforms.Resize((256, 256)),
        # PIL image / ndarray -> tensor.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(img)
    # Add the batch axis in front.
    return tensor.unsqueeze(0)



def _avgpool_correlation_distance(model, preprocess, image_batch_pred, image_batch_gt, device):
    """Return the mean correlation distance between paired 'avgpool' features.

    Both batches are preprocessed, pushed through ``model`` (a feature
    extractor returning a dict with an ``'avgpool'`` entry), flattened per
    image, and compared pair-wise with ``scipy.spatial.distance.correlation``.
    """
    # Imported explicitly so this does not depend on `scipy.spatial` having
    # been loaded as a side effect of some other import.
    from scipy.spatial.distance import correlation

    with torch.no_grad():
        gt = model(preprocess(image_batch_gt.to(device)))['avgpool']
        gt = gt.reshape(len(gt), -1).cpu().numpy()
        fake = model(preprocess(image_batch_pred.to(device)))['avgpool']
        fake = fake.reshape(len(fake), -1).cpu().numpy()

    return np.array([correlation(g, f) for g, f in zip(gt, fake)]).mean()


def run_metric(image_batch_pred, image_batch_gt):
    """Compute and print reconstruction metrics for paired image batches.

    Metrics, in order: pixel correlation, LPIPS (AlexNet), SSIM, two-way
    identification on AlexNet ``features.4`` / ``features.11``, CLIP image
    features and InceptionV3 ``avgpool``, plus mean correlation distance on
    EfficientNet-B1 and SwAV ResNet-50 ``avgpool`` features.

    Args:
        image_batch_pred: Predicted images, a 4-D tensor indexed
            (image, channel, height, width). Assumed to hold RGB values in
            [0, 1] as produced by ``load_all_imgs`` -- the SSIM
            ``data_range=1.0`` and the ImageNet-style normalizations below
            rely on that.
        image_batch_gt: Ground-truth images, same layout; image ``i`` is
            compared against ``image_batch_pred[i]``.

    Returns:
        dict mapping metric name to a float: ``pixcorr``, ``lpips``,
        ``ssim``, ``alexnet2``, ``alexnet5``, ``clip``, ``effnet``,
        ``swav``, ``inception``. (Earlier versions only printed the values
        and returned ``None``.)

    Raises:
        ValueError: if the batches are empty or differ in length.
    """
    n_images = len(image_batch_pred)
    if n_images == 0 or n_images != len(image_batch_gt):
        raise ValueError(
            f"expected two non-empty batches of equal length, got "
            f"{n_images} predictions and {len(image_batch_gt)} ground-truth images"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device:", device)

    # Shared by PixCorr and SSIM: upscale so the shorter side is 425 px.
    resize_425 = transforms.Compose([
        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # PixCorr: Pearson correlation between the flattened pixel vectors of each pair.
    all_images_flattened = resize_425(image_batch_pred).reshape(n_images, -1).cpu()
    all_images_gt_flattened = resize_425(image_batch_gt).reshape(n_images, -1).cpu()

    print(all_images_flattened.shape)
    print(all_images_gt_flattened.shape)

    pixcorr = np.mean([
        np.corrcoef(pred_vec, gt_vec)[0][1]
        for pred_vec, gt_vec in zip(all_images_flattened.numpy(), all_images_gt_flattened.numpy())
    ])
    print("pixcorr: ", pixcorr)

    # LPIPS
    lpips_model = lpips.LPIPS(net='alex').to(device)  # Options: 'alex', 'vgg', 'squeeze'
    with torch.no_grad():
        # normalize=True makes LPIPS rescale [0, 1] inputs to the [-1, 1]
        # range its networks expect; without it the distances are computed
        # on mis-scaled images.
        lpips_per_image = lpips_model(image_batch_pred.to(device), image_batch_gt.to(device),
                                      normalize=True)
    lpips_score = lpips_per_image.mean().item()

    print(f"LPIPS Score: {lpips_score:.4f}")  # Lower is better

    # SSIM on grayscale versions of the upscaled images (channels-last for rgb2gray).
    pred_gray = rgb2gray(resize_425(image_batch_pred).permute((0, 2, 3, 1)).cpu())
    gt_gray = rgb2gray(resize_425(image_batch_gt).permute((0, 2, 3, 1)).cpu())
    print("converted, now calculating ssim...")

    # NOTE(review): the inputs here are 2-D grayscale arrays, yet
    # `multichannel=True` is passed. How scikit-image treats that flag has
    # changed across releases -- confirm against the installed version before
    # altering it, since it affects the reported value.
    ssim_score = [
        ssim(gt_im, pred_im, multichannel=True, gaussian_weights=True, sigma=1.5,
             use_sample_covariance=False, data_range=1.0)
        for pred_im, gt_im in zip(pred_gray, gt_gray)
    ]

    ssim_res = np.mean(ssim_score)
    print("ssim: ", ssim_res)

    # alex net
    from torchvision.models import alexnet, AlexNet_Weights
    alex_weights = AlexNet_Weights.IMAGENET1K_V1

    alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4', 'features.11']).to(
        device)
    alex_model.eval().requires_grad_(False)

    # see alex_weights.transforms()
    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    layer = 'early, AlexNet(2)'
    print(f"\n---{layer}---")
    all_per_correct = two_way_identification(image_batch_pred.to(device).float(), image_batch_gt,
                                             alex_model, preprocess, 'features.4')
    alexnet2 = np.mean(all_per_correct)
    print(f"2-way Percent Correct: {alexnet2:.4f}")

    layer = 'mid, AlexNet(5)'
    print(f"\n---{layer}---")
    all_per_correct = two_way_identification(image_batch_pred.to(device).float(), image_batch_gt,
                                             alex_model, preprocess, 'features.11')
    alexnet5 = np.mean(all_per_correct)
    print(f"2-way Percent Correct: {alexnet5:.4f}")

    # CLIP: only the model is needed; resizing/normalization is done with
    # the tensor transforms below, so no CLIPProcessor is instantiated.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = clip_model.to(device)
    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    all_per_correct = two_way_identification(image_batch_pred.to(device), image_batch_gt.to(device),
                                             clip_model.get_image_features, preprocess, None)  # final layer
    clip_ = np.mean(all_per_correct)
    print(f"CLIP 2-way Percent Correct: {clip_:.4f}")

    # Efficient Net
    from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
    weights = EfficientNet_B1_Weights.DEFAULT
    eff_model = create_feature_extractor(efficientnet_b1(weights=weights),
                                         return_nodes=['avgpool']).to(device)
    eff_model.eval().requires_grad_(False)

    # see weights.transforms()
    preprocess = transforms.Compose([
        transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    effnet = _avgpool_correlation_distance(eff_model, preprocess, image_batch_pred, image_batch_gt, device)
    print("Effnet Distance:", effnet)

    # Swav
    swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
    swav_model = create_feature_extractor(swav_model,
                                          return_nodes=['avgpool']).to(device)
    swav_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    swav = _avgpool_correlation_distance(swav_model, preprocess, image_batch_pred, image_batch_gt, device)
    print("Swav Distance:", swav)

    # InceptionV3
    from torchvision.models import inception_v3, Inception_V3_Weights

    weights = Inception_V3_Weights.DEFAULT
    inception_model = create_feature_extractor(inception_v3(weights=weights),
                                               return_nodes=['avgpool']).to(device)
    inception_model.eval().requires_grad_(False)

    # see weights.transforms()
    preprocess = transforms.Compose([
        transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    all_per_correct = two_way_identification(image_batch_pred.to(device), image_batch_gt.to(device),
                                             inception_model, preprocess, 'avgpool')

    inception = np.mean(all_per_correct)
    print(f"InceptionV3 2-way Percent Correct: {inception:.4f}")

    # Hand the numbers back so callers can log or aggregate them instead of
    # scraping stdout.
    return {
        "pixcorr": float(pixcorr),
        "lpips": float(lpips_score),
        "ssim": float(ssim_res),
        "alexnet2": float(alexnet2),
        "alexnet5": float(alexnet5),
        "clip": float(clip_),
        "effnet": float(effnet),
        "swav": float(swav),
        "inception": float(inception),
    }


@torch.no_grad()
def two_way_identification(image_batch_pred, image_batch_gt, model, preprocess, feature_layer=None, return_avg=True):
    """Score how often a prediction matches its own ground truth best.

    Every image of both batches is passed through ``preprocess``, stacked,
    moved to the available device and fed to ``model``. The resulting
    features are flattened per image and all ground-truth/prediction
    feature pairs are correlated. For each prediction ``j`` the function
    counts the ground-truth images whose correlation with prediction ``j``
    is strictly lower than that of the matching ground truth ``j``.

    Args:
        image_batch_pred: Iterable of predicted images (e.g. a batched tensor).
        image_batch_gt: Iterable of ground-truth images; item ``i`` is the
            match for ``image_batch_pred[i]``.
        model: Callable applied to the stacked, preprocessed batch.
        preprocess: Callable applied to each image before stacking.
        feature_layer: If given, the model output is indexed with this key
            to obtain the features; if ``None`` the output is used directly.
        return_avg: Select the return form (see below).

    Returns:
        If ``return_avg`` is true, the mean count divided by
        ``len(image_batch_gt) - 1`` (NaN when only one pair is supplied,
        because that denominator is zero). Otherwise the tuple
        ``(per-prediction counts, len(image_batch_gt) - 1)``.

    Raises:
        ValueError: if the two batches do not contain the same number of
            images.
    """
    n_images = len(image_batch_gt)
    if len(image_batch_pred) != n_images:
        # Without this check the slicing below either fails with an obscure
        # broadcasting error or silently yields a meaningless score,
        # depending on which batch is longer.
        raise ValueError(
            f"prediction and ground-truth batches must have the same length, "
            f"got {len(image_batch_pred)} and {n_images}"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _features(batch):
        # Preprocess per image, run the model once on the whole batch, then
        # flatten everything but the batch axis into a NumPy matrix.
        output = model(torch.stack([preprocess(image) for image in batch], dim=0).to(device))
        if feature_layer is not None:
            output = output[feature_layer]
        return output.float().flatten(1).cpu().numpy()

    preds = _features(image_batch_pred)
    reals = _features(image_batch_gt)

    # corrcoef stacks reals above preds; the upper-right quadrant holds
    # r[i, j] = corr(reals[i], preds[j]).
    r = np.corrcoef(reals, preds)[:n_images, n_images:]
    # Correlation of each prediction with its own ground truth.
    congruents = np.diag(r)

    # Broadcasting compares column j against congruents[j]; summing over
    # axis 0 gives, per prediction, the number of distractors it beats.
    success_cnt = np.sum(r < congruents, 0)
    n_distractors = n_images - 1

    if return_avg:
        return np.mean(success_cnt) / n_distractors
    return success_cnt, n_distractors


def load_all_imgs(image_dir):
    """Load every ``.jpg`` file in a directory into one image batch.

    Each image is converted to RGB, resized to 256x256 and turned into a
    tensor with ``transforms.ToTensor``. Files are loaded in sorted filename
    order so that two directories containing identically named files yield
    batches whose entries line up index by index.

    Args:
        image_dir: Path of the directory to scan (non-recursive).

    Returns:
        Tensor of stacked images, shape (N, C, H, W). The shape is also
        printed.

    Raises:
        RuntimeError: if the directory contains no ``.jpg`` files.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    # os.listdir() returns entries in arbitrary order; the metrics pair the
    # prediction and ground-truth batches by position, so the order must be
    # deterministic and identical for both directories.
    filenames = sorted(name for name in os.listdir(image_dir) if name.endswith('.jpg'))
    if not filenames:
        raise RuntimeError(f"no '.jpg' images found in {image_dir!r}")

    image_tensors = []
    for filename in filenames:
        # The context manager closes the underlying file handle promptly;
        # convert() produces an independent in-memory image.
        with Image.open(os.path.join(image_dir, filename)) as image:
            image_tensors.append(transform(image.convert('RGB')))

    image_batch = torch.stack(image_tensors)  # (N, C, H, W)
    print(image_batch.shape)
    return image_batch


# Script driver: load the reconstructions from <save_dir>/pred and the
# reference images from <save_dir>/gt (predictions first), then compute and
# print every metric.
# NOTE(review): this runs on import as well as on direct execution; consider
# an `if __name__ == "__main__":` guard if the module is ever imported.
save_dir = "saved_img"
image_god_pred, image_god_gt = (
    load_all_imgs(f"{save_dir}/{subdir}") for subdir in ("pred", "gt")
)

run_metric(image_god_pred, image_god_gt)

